#!/usr/bin/python3
# This file is part of Cockpit.
#
# Copyright (C) 2018 Red Hat, Inc.
#
# Cockpit is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation; either version 2.1 of the License, or
# (at your option) any later version.
#
# Cockpit is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Cockpit; If not, see <http://www.gnu.org/licenses/>.

import os
import sys
import json
import time
import base64
import subprocess

# Program that main() execs (with the target host as its only argument)
# once the optional key-decryption step has completed.
COCKPIT_SSH_COMMAND = "/usr/libexec/cockpit-ssh"

# This command is meant to be used as an authentication command launched
# by cockpit-ws. It asks for authentication from cockpit-ws and expects
# to receive a basic auth header in response. If COCKPIT_SSH_KEY_PATH is set
# we will try to decrypt the key with the given password. If successful
# we send the decrypted key to cockpit-ws for use with cockpit-ssh.
# Once finished we exec cockpit-ssh to actually establish the ssh connection.
# All communication with cockpit-ws happens on stdin and stdout using the
# cockpit protocol
# (https://github.com/cockpit-project/cockpit/blob/main/doc/protocol.md)


def usage():
    """Write the one-line usage summary to stderr and exit with EX_USAGE."""
    program = sys.argv[0]
    sys.stderr.write("usage {} [user@]host[:port]\n".format(program))
    raise SystemExit(os.EX_USAGE)


def send_auth_command(challenge, response):
    """Write an "authorize" control message to file descriptor 1.

    Args:
        challenge: Challenge string to request; when truthy, a fresh cookie
            built from the process id and the current time is attached too.
        response: Response string to send; omitted from the message when
            falsy.

    The message is framed as ``<length>\\n\\n<json>``, where the length counts
    the JSON payload plus the single newline that precedes it.
    """
    message = {"command": "authorize"}

    if challenge:
        message.update(
            cookie="session{}{}".format(os.getpid(), time.time()),
            challenge=challenge,
        )
    if response:
        message["response"] = response

    payload = json.dumps(message).encode('utf-8')
    # +1 accounts for the extra newline emitted right before the payload.
    header = "{}\n\n".format(len(payload) + 1).encode('utf-8')
    os.write(1, header)
    os.write(1, payload)


def send_problem_init(problem, message, auth_methods):
    """Write an "init" control message that reports a problem to fd 1.

    Args:
        problem: Problem code placed in the "problem" field.
        message: Optional human-readable text; left out when falsy.
        auth_methods: Optional mapping stored under "auth-method-results";
            left out when falsy.

    Uses the same ``<length>\\n\\n<json>`` framing as the authorize messages.
    """
    fields = [("command", "init"), ("problem", problem)]
    if message:
        fields.append(("message", message))
    if auth_methods:
        fields.append(("auth-method-results", auth_methods))

    body = json.dumps(dict(fields)).encode('utf-8')
    # The announced length includes the newline written just before the body.
    prefix = "{}\n\n".format(len(body) + 1).encode('utf-8')
    os.write(1, prefix)
    os.write(1, body)


def read_size(fd):
    """Read the decimal length prefix of a frame from *fd*.

    Bytes are consumed one at a time up to and including the terminating
    newline, so nothing belonging to the frame body is read.

    Returns:
        The parsed length, or 0 if end of input is reached before the
        newline (or the line is empty).

    Raises:
        ValueError: If a byte is not a decimal digit, or if an eighth digit
            is read.
    """
    total = 0
    digits = 0

    byte = os.read(fd, 1)
    while byte != b'\n':
        if not byte:
            # End of input before the newline terminator.
            return 0

        total = total * 10 + int(byte)
        digits += 1
        if digits >= 8:
            raise ValueError("Invalid frame: size too long")

        byte = os.read(fd, 1)

    return total


def read_frame(fd):
    """Read one length-prefixed frame from *fd* and return it as text.

    The length comes from read_size(); exactly that many bytes are then
    read and decoded as UTF-8.

    Returns:
        The decoded frame body; an empty string when the announced length
        is 0.

    Raises:
        ValueError: If the input ends before the announced number of bytes
            has arrived, or if the bytes are not valid UTF-8
            (UnicodeDecodeError is a ValueError subclass).
    """
    remaining = read_size(fd)

    chunks = []
    while remaining > 0:
        chunk = os.read(fd, remaining)
        if not chunk:
            # os.read() returns b"" at end of input; without this check the
            # loop would spin forever on a truncated frame.
            raise ValueError("Invalid frame: unexpected end of input")
        chunks.append(chunk)
        remaining -= len(chunk)

    # Decode once over the whole body: a multi-byte character may have been
    # split across two reads, which per-chunk decoding cannot handle.
    return b"".join(chunks).decode('utf8')


def read_auth_reply():
    """Read one frame and return the "response" of an authorize command.

    The frame is read from file descriptor 1.
    NOTE(review): presumably stdin and stdout are the same bidirectional
    socket when launched by cockpit-ws - confirm before changing the fd.

    Returns:
        The value of the "response" field.

    Raises:
        ValueError: If the frame is not valid JSON, is not a JSON object,
            is not an "authorize" command, or lacks a cookie or response.
    """
    data = read_frame(1)
    cmd = json.loads(data)

    # json.loads() can return a list, string or number; calling .get() on
    # those would raise AttributeError, which main() does not catch.
    if not isinstance(cmd, dict):
        raise ValueError("Did not receive a valid authorize command")

    response = cmd.get("response")
    if cmd.get("command") != "authorize" or \
       not cmd.get("cookie") or not response:
        raise ValueError("Did not receive a valid authorize command")

    return response


def decode_basic_header(response):
    """Split a "Basic <base64>" authorization value into user and password.

    Args:
        response: The authorization string, expected to start with
            "Basic " followed by base64 of "user:password" in UTF-8.

    Returns:
        A (user, password) tuple; the password may itself contain ':'.

    Raises:
        ValueError: If the value is not a string starting with "Basic ",
            the base64 or UTF-8 payload is malformed, or the ':' separator
            is missing.
    """
    prefix = "Basic "

    # Explicit checks instead of assert statements, which are stripped under
    # "python -O". The messages deliberately never include the header
    # itself: it carries the credentials and would end up in error output.
    if not isinstance(response, str) or not response.startswith(prefix):
        raise ValueError("Expected a Basic authorization response")

    decoded = base64.b64decode(response[len(prefix):].encode('utf-8')).decode("utf-8")
    user, separator, password = decoded.partition(':')
    if not separator:
        raise ValueError("Basic authorization response has no ':' separator")

    return user, password


def send_decrypted_key(fname, password):
    """Decrypt the private key *fname* with openssl and send it on.

    The password is handed to "openssl rsa" through a pipe ("-passin fd:N")
    so that it never appears on the command line.

    Args:
        fname: Path of the encrypted private key file.
        password: Passphrase used to decrypt the key.

    Returns:
        True if openssl succeeded and the decrypted key was sent as an
        authorize response; False if openssl exited with a non-zero status
        (its stderr output is then written to our stderr).
    """
    r, w = os.pipe()
    try:
        # The read end is handed to openssl via pass_fds. The write end stays
        # non-inheritable and is closed in the child by close_fds (forced on
        # by pass_fds), so no preexec_fn is needed to close it there.
        os.set_inheritable(r, True)
        p = subprocess.Popen(["openssl", "rsa", "-in", fname, "-passin", "fd:{}".format(r)],
                             pass_fds=(r,),
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        os.write(w, password.encode('utf-8'))
    finally:
        # Release both ends even when openssl could not be started.
        os.close(w)
        os.close(r)

    data, err = p.communicate()
    if p.returncode == 0:
        send_auth_command(None, "private-key {}".format(data.decode('utf-8')))
        return True

    # err is bytes; decode it so the diagnostic is not printed as b'...'.
    sys.stderr.write("Couldn't open private key: {}".format(err.decode('utf-8', 'replace')))
    return False


def main(args):
    """Run the authentication helper and then exec cockpit-ssh.

    Args:
        args: Full argument vector; exactly one positional argument (the
            host) is expected after the program name.

    When COCKPIT_SSH_KEY_PATH is set, an authorize challenge is sent, the
    Basic credentials from the reply are used to decrypt that key, and the
    user name is prepended to the host. A malformed reply is reported as
    "internal-error" and re-raised; a key that cannot be decrypted is
    reported as "authentication-failed" and the function returns without
    exec'ing.
    """
    if len(args) != 2:
        usage()

    target = args[1]
    key_path = os.environ.get("COCKPIT_SSH_KEY_PATH")

    if key_path:
        send_auth_command("*", None)
        try:
            user, password = decode_basic_header(read_auth_reply())
        except (ValueError, TypeError, AssertionError) as exc:
            send_problem_init("internal-error", str(exc), {})
            raise

        if not send_decrypted_key(key_path, password):
            send_problem_init("authentication-failed", "Couldn't open private key",
                              {"password": "denied"})
            return

        target = "{}@{}".format(user, target)

    # Replace this process with cockpit-ssh, passing the environment along.
    os.execlpe(COCKPIT_SSH_COMMAND, COCKPIT_SSH_COMMAND, target, os.environ)


# Script entry point: hand the raw command line (including argv[0]) to main().
if __name__ == '__main__':
    main(sys.argv)
